This notebook reproduces experiment 4 in Arvanitidis et al. (2017). We train a convolutional VAE on frames of a video and visualize random walks in the latent space. These walks can be computed using either the Euclidean or the Riemannian metric. Since the Riemannian metric also takes the generator's variance into account, the random walk using the Riemannian metric will avoid regions of high variance.
# Imports and setup of plotting library
%load_ext autoreload
%autoreload 2
%matplotlib inline
from copy import deepcopy
import numpy as np
import tensorflow as tf
from tensorflow.python.keras.datasets import mnist
from tensorflow.python.keras import Sequential, Model
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.layers import Dense, Input, Lambda, Conv2D
from tensorflow.python.keras.layers import Conv2DTranspose, Flatten, Reshape
from tensorflow.python.keras.constraints import NonNeg
from tensorflow.python.keras.initializers import RandomUniform
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import plotly.graph_objs as go
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
from src.vae import VAE
from src.rbf import RBFLayer
from src.videoio import get_frames, load_from_pngs
from src.plot import plot_latent_curve_iterations, plot_magnification_factor
# Set up plotly
init_notebook_mode(connected=True)
# Base figure layout shared by every plot below; cells that need axis titles
# take a deepcopy of it before mutating.
layout = go.Layout(
    width=700,
    height=500,
    # NOTE(review): go.Margin is deprecated in newer plotly releases
    # (go.layout.Margin / plain dict) — confirm the pinned plotly version.
    margin=go.Margin(l=60, r=60, b=40, t=20),
    showlegend=False
)
config={'showLink': False}
# Make results completely repeatable: seed both NumPy and TensorFlow before
# any weights are created or data is shuffled.
seed = 0
np.random.seed(seed)
tf.set_random_seed(seed)
# Debug mode trains only briefly so the whole notebook can be smoke-tested.
debug = False
train_epochs = 10 if debug else 600
We build the VAE following the description in Appendix D in the paper. I have found that subtracting the mean from the input data leads to slightly better reconstructions. Thus, the decoder's mean network has a tanh activation function in its last layer. Apart from this, the VAE below matches exactly the description in the paper.
# Implementation details from Appendix D
input_shape = (64, 64, 3)  # 64x64 RGB video frames
latent_dim = 2             # 2-D latent space so it can be plotted directly
l2_reg = tf.keras.regularizers.l2(1e-5)  # weight decay applied to every layer

# Create the encoder models.
# The mean and variance heads share one convolutional feature extractor
# (`enc_shared`); each head adds its own dense layers on top of it.
enc_input = Input(input_shape)
enc_shared = Sequential([
    Conv2D(32, (3, 3), strides=(2, 2), activation='tanh', padding='same',
           input_shape=input_shape, kernel_regularizer=l2_reg),
    Conv2D(32, (3, 3), strides=(2, 2), activation='tanh', padding='same',
           kernel_regularizer=l2_reg),
    Flatten()
])
enc_mean = Sequential([
    enc_shared,
    Dense(1024, activation='tanh', kernel_regularizer=l2_reg),
    # Linear output: the posterior mean is unconstrained.
    Dense(2, activation='linear', kernel_regularizer=l2_reg)
])
enc_var = Sequential([
    enc_shared,
    Dense(1024, activation='tanh', kernel_regularizer=l2_reg),
    # Softplus keeps the predicted posterior variance positive.
    Dense(2, activation='softplus', kernel_regularizer=l2_reg)
])
# Wrap both heads into models that take the raw frame as input.
enc_mean = Model(enc_input, enc_mean(enc_input))
enc_var = Model(enc_input, enc_var(enc_input))
# Create the decoder models
dec_input = Input((latent_dim,))
# Decoder mean network: dense layers, then transposed convolutions that
# upsample 16x16x3 feature maps to 64x64x3 frames. The final tanh outputs
# values in (-1, 1), matching the mean-subtracted training frames.
dec_mean = Sequential([
    Dense(1024, activation='tanh', kernel_regularizer=l2_reg),
    Dense(16 * 16 * 3, activation='tanh', kernel_regularizer=l2_reg),
    Reshape((16, 16, 3)),
    Conv2DTranspose(32, (3, 3), strides=(2, 2), activation='tanh',
                    padding='same', kernel_regularizer=l2_reg),
    Conv2DTranspose(32, (3, 3), strides=(2, 2), activation='tanh',
                    padding='same', kernel_regularizer=l2_reg),
    Conv2DTranspose(3, (3, 3), strides=(1, 1), activation='tanh',
                    padding='same', kernel_regularizer=l2_reg),
    Conv2D(3, (3, 3), strides=(1, 1), activation='tanh', padding='same',
           kernel_regularizer=l2_reg)
])
# Build the RBF network that models the decoder variance. `a` is the
# bandwidth scaling factor used later when bandwidths are derived from the
# k-means clusters.
num_centers = 64
a = 2.0
# NOTE(review): RBFLayer is project-local (src/rbf.py); [32, 32, 3] is
# presumably its output shape — confirm there.
rbf = RBFLayer([32, 32, 3], num_centers)
# Non-negative kernels/biases on linear layers keep the variance output >= 0.
var_constraint = NonNeg()
dec_var = Sequential([
    rbf,
    Conv2DTranspose(1, (3, 3), strides=(2, 2), activation='linear',
                    padding='same', kernel_constraint=NonNeg(),
                    bias_constraint=NonNeg(),
                    kernel_initializer=RandomUniform(minval=0, maxval=0.05),
                    kernel_regularizer=l2_reg),
    Conv2D(3, (3, 3), strides=(1, 1), activation='linear',
           padding='same', kernel_constraint=var_constraint,
           bias_constraint=var_constraint,
           kernel_initializer=RandomUniform(minval=0, maxval=0.05),
           kernel_regularizer=l2_reg),
])
dec_mean = Model(dec_input, dec_mean(dec_input))
dec_var = Model(dec_input, dec_var(dec_input))
# Assemble the full VAE from the four networks (src/vae.py).
vae = VAE(enc_mean, enc_var, dec_mean, dec_var, dec_stddev=1.)
Load the training frames, shuffle them, and subtract the mean.
# Load the video frames as an array of images.
# NOTE(review): assumes load_from_pngs expands '~' — confirm in src/videoio.py.
x_train = load_from_pngs('~/Desktop/trump-cut/')
# Shuffle the training data, but save the permutation for later so the
# original frame order can be reconstructed for the video cell below.
permutation = np.random.permutation(len(x_train))
x_train = x_train[permutation]
# Subtract the mean frame; it is added back whenever frames are displayed.
x_mean = np.mean(x_train, axis=0)
x_train -= x_mean
plt.imshow(x_mean)
# Sanity check: show one re-centered training frame.
x_plot = x_train[10]
plt.imshow(x_plot + x_mean)
# Train the full VAE. Training is unsupervised, so only inputs are passed.
history = vae.model.fit(x_train,
                        epochs=train_epochs,
                        batch_size=32,
                        validation_split=0.1,
                        verbose=0)
# Plot the train/validation losses (negative ELBO) over epochs.
data = [go.Scatter(y=history.history['loss'], name='Train Loss'),
        go.Scatter(y=history.history['val_loss'], name='Validation Loss')]
# Copy the shared layout so the base `layout` object stays untouched.
plot_layout = deepcopy(layout)
plot_layout['xaxis'] = {'title': 'Epoch'}
plot_layout['yaxis'] = {'title': 'NELBO'}
plot_layout['showlegend'] = True
iplot(go.Figure(data=data, layout=plot_layout), config=config)
Training takes about 8 hours, so we can simply reload the trained VAE here. This is why the execution count is suddenly larger in the cells below.
# Persist the trained networks, then reload them. In debug mode nothing is
# saved, so reloading assumes previously saved models exist on disk.
if not debug:
    vae.encoder.save('models/video-encoder.h5', include_optimizer=False)
    vae.decoder.save('models/video-generator.h5', include_optimizer=False)
from src.vae import load_from
vae = load_from('models/video-encoder.h5', 'models/video-generator.h5')
# Re-extract the RBF layer from the reloaded decoder.
# NOTE(review): relies on the exact nesting decoder.layers[2].layers[0] —
# verify against the architecture built above if it ever changes.
rbf = vae.decoder.layers[2].layers[0]
Both in the paper and here, we see chains of latent points indicating that frames in a sequence end up next to each other in the latent space.
# Display a 2D plot of the latent representations of all training frames.
# The encoder's three outputs are unpacked as (sample, posterior mean,
# posterior variance); the posterior mean is what gets plotted.
sampled, encoded_mean, encoded_var = vae.encoder.predict(x_train)

# Scatter the posterior means in the 2-D latent space.
scatter_plot = go.Scatter(
    x=encoded_mean[:, 0],
    y=encoded_mean[:, 1],
    mode='markers',
    marker=dict(color='orange')
)
data = [scatter_plot]
iplot(go.Figure(data=data, layout=layout), config=config)
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
from moviepy.video.io.html_tools import ipython_display
# NOTE(review): scipy.misc.imresize was deprecated in SciPy 1.0 and removed
# in 1.3 — this cell requires an older SciPy; confirm the pinned version.
from scipy.misc import imresize
# Display a sequence and its reconstructed version side by side
seq_length = 1000
# Invert the permutation with np.argsort to recover the first seq_length
# frames in their original video order.
seq_indices = np.argsort(permutation)[:seq_length]
sequence = x_train[seq_indices]
# The middle decoder output is used as the reconstruction (presumably the
# decoder mean — confirm against src/vae.py).
_, reconstructed, _ = vae.decoder.predict(encoded_mean[seq_indices])
frames = []
for i in range(len(sequence)):
    # Original frame on the left, reconstruction on the right.
    frame = np.concatenate([sequence[i], reconstructed[i]], axis=1)
    # Add the mean frame back (training data was mean-subtracted) and map
    # into the 0-255 range expected by moviepy.
    frame += np.concatenate([x_mean, x_mean], axis=1)
    frame = np.clip(frame, 0, 1) * 255.0
    # Scale from 64x64 per image to 256x256
    frame = imresize(frame, 4.0, 'nearest')
    frames.append(frame)
clip = ImageSequenceClip(sequence=frames, fps=30)
ipython_display(clip)
First, find the centers of the latent representations.
# Find the RBF centers by running k-means on the latent posterior means
# (random_state fixed for repeatability).
kmeans_model = KMeans(n_clusters=num_centers, random_state=0).fit(encoded_mean)
centers = kmeans_model.cluster_centers_

# Overlay the cluster centers (red) on the latent scatter plot.
center_plot = go.Scatter(
    x=centers[:, 0],
    y=centers[:, 1],
    mode='markers',
    marker=dict(color='red')
)
data = [scatter_plot, center_plot]
iplot(go.Figure(data=data, layout=layout), config=config)
# Group every latent point under its nearest k-means center.
assignments = kmeans_model.predict(encoded_mean)
clustering = {c_i: [] for c_i in range(num_centers)}
for z_i, c_i in zip(encoded_mean, assignments):
    clustering[c_i].append(z_i)

# One RBF bandwidth per center, 0.5 / (a * avg_dist)^2, where avg_dist is
# the mean distance of the cluster's points to its center. Empty clusters
# fall back to bandwidth 0.
bandwidths = []
for c_i in range(num_centers):
    cluster = clustering[c_i]
    if cluster:
        diffs = np.array(cluster) - centers[c_i]
        avg_dist = np.mean(np.linalg.norm(diffs, axis=1))
        bandwidths.append(0.5 / (a * avg_dist)**2)
    else:
        bandwidths.append(0)
bandwidths = np.array(bandwidths)
Train the RBF variance network while keeping all other parameters of the VAE fixed, as described in the paper.
# Train the RBF variance network while the rest of the VAE stays fixed.
# NOTE(review): recompile_for_var_training presumably freezes the other
# weights — confirm in src/vae.py.
vae.recompile_for_var_training()
# Initialize the RBF layer: keep its existing kernel, install the k-means
# centers and the bandwidths computed above.
rbf_kernel = rbf.get_weights()[0]
rbf.set_weights([rbf_kernel, centers, bandwidths])
history = vae.model.fit(x_train,
                        epochs=1,
                        batch_size=32,
                        validation_split=0.1,
                        verbose=0)
# Extract the mean and std predictors
from src.util import wrap_model_in_float64
# The decoder has three outputs; the second and third are unpacked as the
# mean and variance. Build flattened models for the mean and the standard
# deviation (sqrt of variance) for the metric computation below.
_, mean, var = vae.decoder.output
std = Lambda(tf.sqrt)(var)
dec_mean = Model(vae.decoder.input, Flatten()(mean))
dec_std = Model(vae.decoder.input, Flatten()(std))
# float64 wrappers improve the numerical accuracy of the Jacobians used by
# the Riemannian metric (see src/util.py).
dec_mean = wrap_model_in_float64(dec_mean)
dec_std = wrap_model_in_float64(dec_std)
# Plot the magnification factor over a square grid that covers all latent
# encodings, with a margin of 10 on each side.
axis_length = max(abs(encoded_mean.min()), encoded_mean.max()) + 10
grid_resolution = 3 if debug else 200  # coarse grid in debug mode
heatmap_z1 = np.linspace(-axis_length, axis_length, grid_resolution)
heatmap_z2 = np.linspace(-axis_length, axis_length, grid_resolution)
heatmap = plot_magnification_factor(K.get_session(),
                                    heatmap_z1,
                                    heatmap_z2,
                                    dec_mean,
                                    dec_std,
                                    additional_data=[scatter_plot],
                                    layout=layout,
                                    log_scale=True,
                                    scale='hotcold')
Define the Riemannian metric and the Euclidean metric (which is the identity matrix).
# Let's take a random walk
from tqdm import tqdm
from src.util import get_metric_op, get_numerical_jacobian

session = K.get_session()


def jac_fun(output_tensor, input_tensor):
    """Numerically approximate the Jacobian of output_tensor w.r.t.
    input_tensor in the current session (see src/util.py)."""
    return get_numerical_jacobian(session, output_tensor, input_tensor)


# Build the Riemannian metric op once; `point` is fed with the latent
# position at which the metric should be evaluated.
point = tf.placeholder(tf.float64, [2])
metric_op = get_metric_op(point, dec_mean, dec_std, jac_fun=jac_fun)


def get_riemannian(position):
    """Evaluate the Riemannian metric at `position` (presumably a 2x2
    matrix for the 2-D latent space — see src/util.py)."""
    return session.run(metric_op, feed_dict={point: position})
def get_euclidean(position):
    """Return the Euclidean metric at ``position``: the identity matrix."""
    dim = len(position)
    return np.identity(dim)
def random_walk(metric_fun, num_steps=1000, step_size=1.):
    """Take a metric-aware random walk through the latent space.

    At each step, isotropic noise is transformed with the inverse square
    root of the local metric, so the walk takes long steps where the metric
    is small and short steps where it is large (e.g. in high-variance
    regions under the Riemannian metric).

    Args:
        metric_fun: callable mapping a latent position (length-2 array) to
            a symmetric positive-definite metric matrix at that position.
        num_steps: number of steps to take.
        step_size: scalar multiplier applied to every step.

    Returns:
        Array of shape (num_steps + 1, 2): all visited positions, starting
        at the origin.
    """
    position = np.array([0., 0.])
    walk = [np.copy(position)]
    for _ in tqdm(range(num_steps), 'Taking Random Walk'):
        metric = metric_fun(position)
        # The metric is symmetric, so use eigh instead of eig: eig can
        # return spurious complex eigenvalues/eigenvectors from round-off
        # asymmetry (propagating complex positions), while eigh is exact
        # about symmetry and more accurate.
        eigvals, eigvecs = np.linalg.eigh(metric)
        noise = np.random.randn(2)
        # V * diag(lambda^-1/2) @ noise draws a step from N(0, metric^-1).
        v = (eigvecs * (eigvals ** -0.5)).dot(noise)
        position += step_size * v
        walk.append(np.copy(position))
    return np.vstack(walk)
# Take one walk under each metric. Keep this order: both walks draw from
# the shared NumPy RNG, so swapping them changes the realized paths.
riemannian_walk = random_walk(get_riemannian)
euclidean_walk = random_walk(get_euclidean)

# Draw both walks as thin lines over the latent scatter plot.
riemannian_plot = go.Scatter(
    x=riemannian_walk[:, 0],
    y=riemannian_walk[:, 1],
    mode='lines',
    line=dict(width=1, color='green')
)
euclidean_plot = go.Scatter(
    x=euclidean_walk[:, 0],
    y=euclidean_walk[:, 1],
    mode='lines',
    line=dict(width=1, color='#ff005a')
)
data = [scatter_plot, euclidean_plot, riemannian_plot]
iplot(go.Figure(data=data, layout=layout), config=config)
Finally, decode and display the frames generated at different steps of the random walks.
from src.plot import plot_images

# Visualize some steps in the random walks: decode the latent position of
# each walk at the chosen steps and show the two frames side by side.
steps = [0, 200, 300, 800, 900, 1000]
images = {}
for step in steps:
    euclidean_position = euclidean_walk[step]
    riemannian_position = riemannian_walk[step]
    # Decode both positions in one batch. The middle decoder output is used
    # as the reconstruction (presumably the mean — confirm in src/vae.py).
    _, (euclidean_frame, riemannian_frame), _ = vae.decoder.predict(np.array([
        euclidean_position,
        riemannian_position
    ]))
    # Undo the mean subtraction and map into the 0-255 display range.
    euclidean_frame += x_mean
    riemannian_frame += x_mean
    euclidean_frame = np.clip(euclidean_frame, 0, 1) * 255.0
    riemannian_frame = np.clip(riemannian_frame, 0, 1) * 255.0
    # Scale from 64x64 per image to 256x256
    # NOTE(review): scipy.misc.imresize was removed in SciPy 1.3 — confirm
    # the pinned SciPy version.
    euclidean_frame = imresize(euclidean_frame, 4.0, 'nearest')
    riemannian_frame = imresize(riemannian_frame, 4.0, 'nearest')
    images['step %d euclidean' % step] = euclidean_frame
    images['step %d riemannian' % step] = riemannian_frame
plot_images(images, nrows=6, ncols=2)